# -*-Python-*-
include 'models/at5_base.gin'

import mesh_tensorflow.transformer
import mesh_tensorflow.transformer.funnel_transformer

# Build the bitransformer with the Funnel variant. The encoder layer stack is a
# FunnelTransformerLayerStack, which uses its own definition of "block", so
# block_scope is turned off for the encoder-scoped make_layer_stack.
encoder/make_layer_stack.block_scope = False
transformer.make_bitransformer.bitransformer_cls = @funnel_transformer.BitransformerFunnel
encoder/make_layer_stack.layer_stack_cls = @FunnelTransformerLayerStack

# Use the standard LayerStack for the decoder (only the encoder is a funnel).
decoder/make_layer_stack.layer_stack_cls = @LayerStack

# Sublayer sequences for the funnel encoder stack.
# Initial sublayers: dropout only.
funnel_transformer.FunnelTransformerLayerStack.sublayers_initial = [
    @transformer.sublayer_dropout,
]
# Per-layer sublayers: RMS norm, the layer call, dropout, then the residual.
funnel_transformer.FunnelTransformerLayerStack.sublayers_per_layer = [
    @transformer.sublayer_rms_norm,
    @transformer.sublayer_call_layer,
    @transformer.sublayer_dropout,
    @transformer.sublayer_residual,
]
# Final sublayers: RMS norm followed by dropout.
funnel_transformer.FunnelTransformerLayerStack.sublayers_final = [
    @transformer.sublayer_rms_norm,
    @transformer.sublayer_dropout,
]

# Total 12 encoder layers: 3 blocks with sizes 4 + 4 + 4, each repeated once.
# This is consistent with num_encoder_layers = 12 set in this file.
# NOTE(review): assumes block_param_size counts layers per block; confirm
# against FunnelTransformerLayerStack if these values are changed.
FunnelTransformerLayerStack.n_blocks = 3
FunnelTransformerLayerStack.block_param_size = [4, 4, 4]
FunnelTransformerLayerStack.block_repeat_size = [1, 1, 1]

# Two submodules per encoder layer: SelfAttention and DenseReluDense.
FunnelTransformerLayerStack.n_submodules = 2

# Layer counts, defined as macros and referenced with `%` in the encoder- and
# decoder-scoped make_layer_stack bindings. Keep num_encoder_layers in sync
# with the sum of FunnelTransformerLayerStack.block_param_size (4 + 4 + 4).
num_encoder_layers = 12
num_decoder_layers = 12
encoder/transformer.make_layer_stack.num_layers = %num_encoder_layers
decoder/transformer.make_layer_stack.num_layers = %num_decoder_layers

# Packing is not supported for Funnel Transformer, so turn it off for
# mesh_train_dataset_fn.
mesh_train_dataset_fn.pack = False
